library(here)
library(readr)
library(wru)
library(dplyr)
library(pROC)
library(reshape)
library(ppmf)
library(ggplot2)

# number of equal-width bins on [0, 1] used for the calibration curves
# computed in predictAndEvaluate()
bins <- 20

##################################################
##         read in all the data sources         ##
##################################################

# load helpers
# NOTE(review): names used further down but not defined in this file
# (e.g. PAL_DAS, PAL_race) presumably come from this script or from an
# attached package -- confirm
source(here("R/00_custom_functions.R"))

# get the census data
# expected to provide `censusData`, which trueBlockData() indexes by state
load(here("data/wru-NC/censusData.rData"))

# paths to read in the voter file data
dir.vf <- here("data/wru-NC/NC.csv")

# path to read in the block data
# (a directory; dpData() builds the per-state file names from it)
dir.blockData <- here("data/wru-NC/Block Baselines/")

# name probabilities
# expected to provide `dict_first`, `dict_middle` and `dict_last`, which
# predictAndEvaluate() joins against -- TODO confirm the object names stored
# in these files
load(here("data/wru-NC/first.RData"))
load(here("data/wru-NC/mid.RData"))
load(here("data/wru-NC/last.RData"))

# state abbreviation handed to the data readers and evaluation calls below
state <- 'NC'

##################################################
##         function to make predictions         ##
##################################################

# Pick the most probable group for every row of `data`.
#
# `data` must hold one probability column per group; a column is located by
# the group tag ('whi', 'bla', 'his', 'asi', 'oth') appearing in its name.
# A row is labelled 'whi', 'bla', 'his' or 'asi' only when that column is
# strictly greater than all four other columns; every remaining row
# (including ties for the maximum) is labelled 'oth'. Rows containing a
# missing probability yield NA.
mapFunction <- function(data) {
	groups <- c('whi', 'bla', 'his', 'asi', 'oth')

	# pull each group's probability column once
	probs <- lapply(groups, function(tag) data[[grep(tag, names(data))]])
	names(probs) <- groups

	# build the labels from the fallback outwards, so that groups listed
	# earlier take precedence exactly as in a nested ifelse() chain
	mapPred <- 'oth'
	for (g in rev(groups[-length(groups)])) {
		rivals <- unname(probs[setdiff(groups, g)])
		isTop <- probs[[g]] > do.call(pmax, rivals)
		mapPred <- ifelse(isTop, g, mapPred)
	}
	return(mapPred)
}

# Predict race/ethnicity for every record of `testSet` and score the result.
#
# Three nested predictions are produced: block prior x surname ("last"),
# additionally x first name ("first"), additionally x middle name ("mid").
#
# Arguments:
#   blockData  block-level priors; joined on `county`, `tract`, `block` and
#              expected to carry p_whi, p_bla, p_his, p_asi, p_oth
#   testSet    voter records with Voters_FIPS, Residence_Addresses_CensusTract,
#              Residence_Addresses_CensusBlock, Voters_FirstName,
#              Voters_MiddleName, Voters_LastName, CountyEthnic_Description
#   name       not used in the body; kept so existing calls keep working
#   state      not used in the body; kept so existing calls keep working
#
# Reads the globals `dict_last`, `dict_first`, `dict_middle` and `bins`, and
# calls mapFunction().
#
# Returns a list with
#   rocVals      5 x 3 matrix of one-vs-rest ROC AUCs (rows: group, cols: method)
#   calibration  list (one entry per method) of bins x 5 matrices holding the
#                observed share of each group within each probability bin
#   finalTable   data frame with the overall error rate and, per group, the
#                rows labelled FNR and FPR
predictAndEvaluate <- function(blockData, testSet, name, state) {

	races <- c("whi", "bla", "his", "asi", "oth")
	methods <- c("last", "first", "mid")

	# divide every row of a probability table by its row total
	normalizeRows <- function(probs) {
		probs <- as.matrix(probs)
		data.frame(probs / rowSums(probs))
	}

	# build the geoBaselines
	geoBaselines <- left_join(testSet, blockData,
		by = c("Voters_FIPS" = "county", "Residence_Addresses_CensusTract" = "tract",
		       "Residence_Addresses_CensusBlock" = "block"))
	# seq_len() stays valid for an empty test set, unlike 1:nrow()
	geoBaselines$ID <- seq_len(nrow(geoBaselines))

	# case everything to upper case
	geoBaselines$Voters_FirstName <- toupper(geoBaselines$Voters_FirstName)
	geoBaselines$Voters_MiddleName <- toupper(geoBaselines$Voters_MiddleName)
	geoBaselines$Voters_LastName <- toupper(geoBaselines$Voters_LastName)

	# make the surname predictions
	allData <- left_join(geoBaselines, dict_last, by = c("Voters_LastName" = "last_name"), suffix = c(".geo", ".last"))
	# names missing from the dictionary get a constant factor of 1, so the
	# normalized prediction falls back to the previous stage
	allData[is.na(allData$p_whi.last), grep("last", colnames(allData))] <- 1

	preds.last <- allData[,names(allData)[grepl('geo', names(allData))]] * allData[,names(allData)[grepl('last', names(allData))]]
	preds.last.normalized <- normalizeRows(preds.last)
	colnames(preds.last.normalized) <- paste(gsub('p_|\\.geo', '', colnames(preds.last.normalized)), "last", sep = "_")

	# make the census first name predictions
	# (work on a local copy so the probability columns can be tagged ".first")
	firstDict <- dict_first
	firstProbCols <- grep("p_", names(firstDict))
	names(firstDict)[firstProbCols] <- paste0(names(firstDict)[firstProbCols], ".first")
	allData <- left_join(allData, firstDict, by = c("Voters_FirstName" = "first_name"))
	allData[is.na(allData$p_whi.first), grep("first", colnames(allData))] <- 1 	# deal with missing names

	preds.first <- preds.last * allData[,names(allData)[grepl('first', names(allData))]]
	preds.first.normalized <- normalizeRows(preds.first)
	colnames(preds.first.normalized) <- gsub("last", "first", names(preds.last.normalized))

	# make the middle name preds
	middleDict <- dict_middle
	middleProbCols <- grep("p_", names(middleDict))
	names(middleDict)[middleProbCols] <- paste0(names(middleDict)[middleProbCols], ".middle")
	allData <- left_join(allData, middleDict, by = c("Voters_MiddleName" = "middle_name"))
	allData[is.na(allData$p_whi.middle), grep("middle", colnames(allData))] <- 1 # deal with missing names

	preds.middle <- preds.first * allData[,names(allData)[grepl('middle', names(allData))]]
	preds.middle.normalized <- normalizeRows(preds.middle)
	colnames(preds.middle.normalized) <- gsub("last", "mid", names(preds.last.normalized))

	# organize the data: collapse the reported categories to the five groups
	true_race <- ifelse(allData$CountyEthnic_Description == "White Self Reported", "whi",
                             ifelse(allData$CountyEthnic_Description == "African or Af-Am Self Reported", "bla",
                             ifelse(allData$CountyEthnic_Description == "Hispanic", "his",
                             ifelse(allData$CountyEthnic_Description == "East Asian" | allData$CountyEthnic_Description == "Korean", "asi", "oth"))))

	evalData <- data.frame(allData, true_race, preds.last.normalized, preds.first.normalized, preds.middle.normalized)

	# roc aucs (one-vs-rest for every group and method)
	rocVals <- vapply(methods, FUN = function(method) {
		vapply(races, FUN = function(eth) {
			as.numeric(auc(roc(evalData$true_race == eth, evalData[[paste(eth, method, sep = "_")]],
				direction = "<", levels = c(FALSE, TRUE))))
		}, FUN.VALUE = numeric(1))
	}, FUN.VALUE = numeric(length(races)))

	# calibration plots
	# Bins are [lo, hi) with the last bin closed on the right, so predictions
	# of exactly 0, exactly 1 or exactly on an inner edge are counted once
	# instead of being dropped.
	binDelimiters <- seq(from = 0, to = 1, length.out = bins + 1)
	calibration <- lapply(methods, FUN = function(method) {
		vapply(races, FUN = function(eth) {
			binOf <- findInterval(evalData[[paste(eth, method, sep = "_")]], binDelimiters,
				rightmost.closed = TRUE)
			isEth <- evalData$true_race == eth
			vapply(seq_len(bins), FUN = function(i) {
				mean(isEth[which(binOf == i)], na.rm = TRUE)
			}, FUN.VALUE = numeric(1))
		}, FUN.VALUE = numeric(bins))
	})

	# accuracy statistics
	evalData$map_last <- mapFunction(evalData[,grep("_last", names(evalData))])
	evalData$map_first <- mapFunction(evalData[,grep("_first", names(evalData))])
	evalData$map_mid <- mapFunction(evalData[,grep("_mid", names(evalData))])

	# get the success measures table
	mapVals <- c("map_last", "map_first", "map_mid")
	errorRates <- 1 - vapply(mapVals, FUN = function(pred) {
		mean(evalData[[pred]] == evalData$true_race, na.rm = TRUE)
	}, FUN.VALUE = numeric(1))

	# per group: FNR = share of records truly in the group that were predicted
	# as something else; the value labelled FPR is the share of records
	# predicted as the group whose true group differs (i.e. 1 - precision)
	errorRates.byRace <- vapply(mapVals, FUN = function(pred) {
		unlist(lapply(races, FUN = function(r) {
			predictedR <- evalData[[pred]] == r
			trulyR <- evalData$true_race == r
			FPR <- 1 - mean(evalData[[pred]][predictedR] == evalData$true_race[predictedR], na.rm = TRUE)
			FNR <- 1 - mean(evalData[[pred]][trulyR] == evalData$true_race[trulyR], na.rm = TRUE)
			c(FNR, FPR)
		}))
	}, FUN.VALUE = numeric(2 * length(races)))

	finalTable <- data.frame(round(rbind(errorRates, errorRates.byRace), 3))
	finalTable$race <- c("Overall Error Rate", "W", "W", "B", "B", "H", "H", "A", "A", "O", "O")
	finalTable$Data <- c("", "FNR", "FPR", "FNR", "FPR", "FNR", "FPR", "FNR", "FPR", "FNR", "FPR")
	rownames(finalTable) <- c()
	# move the two label columns to the front
	finalTable <- finalTable[, c(ncol(finalTable) - 1, ncol(finalTable), 1:(ncol(finalTable) - 2))]

	return(list(rocVals = rocVals,
		    calibration = calibration,
		    finalTable = finalTable))

}

####################################################
##             data reading functions             ##
####################################################

# true block data function
#
# Turns the block-level counts stored in the global `censusData[[state]]$block`
# into group shares. Each share is the group's count column(s) divided by the
# row sum over all columns whose name contains "P00". Returns the identifier
# columns (state, county, tract, block) followed by p_whi, p_bla, p_his,
# p_asi and p_oth.
trueBlockData <- function(state) {

	blocks <- censusData[[state]]$block

	# denominator shared by all five shares
	total <- rowSums(blocks[, grep("P00", names(blocks))])

	blocks$p_whi <- blocks$P005003 / total
	blocks$p_bla <- blocks$P005004 / total
	blocks$p_his <- blocks$P005010 / total
	blocks$p_asi <- blocks$P005006 / total
	# "other" pools the four remaining count columns
	blocks$p_oth <- (blocks$P005005 + blocks$P005007 + blocks$P005008 + blocks$P005009) / total

	keep <- c("state", "county", "tract", "block",
	          "p_whi", "p_bla", "p_his", "p_asi", "p_oth")
	return(blocks[, keep])
}

# dp block data function
#
# Reads the two dp block baseline files for `state` from the directory held
# in the global `dir.blockData`:
#   <state, lower case>_blockBaselines_4.csv
#   <state, lower case>_blockBaselines_12.csv
# Each file is expected to contain a `block` identifier column and the
# columns r_whi, r_bla, r_his, r_asi, r_oth; its first column is discarded.
#
# Returns list(blockData.dp4 = ..., blockData.dp12 = ...), each a data frame
# with state, county, tract, block and p_whi, p_bla, p_his, p_asi, p_oth.
dpData <- function(state) {

	# accept `dir.blockData` with or without a trailing path separator
	baseDir <- sub("[/\\\\]+$", "", dir.blockData)

	# read one baseline file and split the block identifier into its parts
	readBaselines <- function(suffix) {
		path <- file.path(baseDir, paste0(tolower(state), "_blockBaselines_", suffix, ".csv"))
		# keep the identifier as text: parsed as a number it would lose
		# leading zeros (or print in scientific notation) before substr()
		raw <- read.csv(path, colClasses = c(block = "character"))[,-1]
		data.frame(state = state,
			   county = substr(raw$block, 3, 5),
			   tract = substr(raw$block, 6, 11),
			   block = substr(raw$block, 12, 15),
			   p_whi = raw$r_whi,
			   p_bla = raw$r_bla,
			   p_his = raw$r_his,
			   p_asi = raw$r_asi,
			   p_oth = raw$r_oth)
	}

	# return the data
	return(list(blockData.dp4 = readBaselines("4"), blockData.dp12 = readBaselines("12")))
}

#################################################
##          compute the summary stats          ##
#################################################

# read the voter file that serves as the test set; the census block and tract
# columns are join keys in predictAndEvaluate(), so they are stored as text
# NOTE(review): if read_csv() parsed these columns as numbers, any leading
# zeros are already gone at this point -- confirm the keys match the block data
testSet <- read_csv(dir.vf)
testSet[["Residence_Addresses_CensusBlock"]] <- as.character(testSet[["Residence_Addresses_CensusBlock"]])
testSet[["Residence_Addresses_CensusTract"]] <- as.character(testSet[["Residence_Addresses_CensusTract"]])

# block-level priors: the census counts and the two dp baselines
blockData.true <- trueBlockData(state)
blockData.dp <- dpData(state)
blockData.dp4 <- blockData.dp[["blockData.dp4"]]
blockData.dp12 <- blockData.dp[["blockData.dp12"]]

# evaluate the predictions once per prior
trueStats <- predictAndEvaluate(blockData = blockData.true, testSet = testSet, name = "true", state = state)
dp4Stats <- predictAndEvaluate(blockData = blockData.dp4, testSet = testSet, name = "dp4", state = state)
dp12Stats <- predictAndEvaluate(blockData = blockData.dp12, testSet = testSet, name = "dp12", state = state)

############################################
##              make visuals              ##
############################################

# gather the ROC AUC matrices of all priors into one long table
# (one row per group x name set x prior; DAS-19.6 results are not included)
longData <- data.frame(do.call(rbind, Map(
  function(stats, label) cbind(melt(stats$rocVals), method = label),
  list(trueStats, dp12Stats, dp4Stats),
  c("Census 2010", "DAS-12.2", "DAS-4.5")
)))
names(longData) <- c("Ethnicity", "Names_Used", "ROC_AUC", "Method")
longData$ROC_AUC <- round(longData$ROC_AUC, 3)

# turn the short codes into ordered, human-readable factor labels
longData$Names_Used <- factor(as.character(longData$Names_Used),
                              levels = c('last', 'first', 'mid'),
                              labels = c('Last Names Only', 'Last + First Names',
                                         'Last + First +\nMiddle Names'))
longData$Ethnicity <- factor(longData$Ethnicity,
                             levels = c('whi', 'bla', 'his', 'asi', 'oth'),
                             labels = c('White', 'Black', 'Hispanic', 'Asian', 'Other'))
longData$Method <- factor(longData$Method, levels = c('Census 2010', 'DAS-12.2', 'DAS-4.5'))

# plot the ROC values for 'Census 2010', 'DAS-4.5', 'DAS-12.2'
# one bar per geographic prior, faceted by name set (rows) and group (columns)
longData %>%
  ggplot(aes(x = Method, y = ROC_AUC)) + aes(fill = Method) + theme_ppmf() +
  geom_bar(stat = 'identity') +
  facet_grid(rows = vars(Names_Used), cols = vars(Ethnicity)) +
  coord_cartesian(ylim = c(0.65, 1)) +
  # nsmall = 3 always shows three decimals; a bare positional 3 would be
  # taken as format()'s `trim` argument
  geom_text(aes(label = format(ROC_AUC, nsmall = 3)),
            position=position_dodge(width=0.3),
            vjust=-0.25,
            size = 2.5,
            family = "serif") +
  scale_fill_manual(values = PAL_DAS) +
  scale_y_continuous(labels = scales::percent_format(accuracy = 1), expand = expansion()) +
    scale_x_discrete(labels = NULL, breaks = NULL) +
  labs(fill = "Geographic Prior", y = "Area Under the Receiver Operating Characteristic Curve (AUROC)", x = NULL)
# resolve the output path via here(), like every input path in this script,
# so the figure location does not depend on the working directory
ggsave(here("figs/wru_nc_roc_plots.pdf"), width = 10*0.75, height = 5*0.75)

# plot for 'DAS-19.6'
# NOTE: no 'DAS-19.6' rows are added to longData in this script, so the plot
# is only drawn when such rows exist (previously an empty data set was
# plotted and saved).
dp19Data <- longData %>%
  filter(Method == 'DAS-19.6')
if (nrow(dp19Data) > 0) {
  dp19Plot <- dp19Data %>%
    ggplot(aes(x = Ethnicity, y = ROC_AUC)) + aes(fill = Ethnicity) + theme_ppmf() +
    theme(axis.text.x=element_blank(), axis.ticks.x=element_blank()) +
    geom_bar(stat = 'identity') +
    facet_grid(cols = vars(Names_Used)) +
    coord_cartesian(ylim = c(0.65, 1)) +
    geom_text(aes(label=ROC_AUC), position=position_dodge(width=0.3), vjust=-0.25, size = 3.0) +
    scale_fill_manual(values = c(PAL_race, Asian = "#3E6E94", Other = "lightgray")) +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    labs(fill = "Race/\nEthnicity", y = "AUROC", x = '')
  # here() replaces the former setwd() to a user-specific directory
  ggsave(here("figs/wru_nc_roc_plots_19_only.pdf"), plot = dp19Plot, width = 7.5, height = 3)
}


# print the tables with statistics
print(trueStats$finalTable)
print(dp4Stats$finalTable)
print(dp12Stats$finalTable)
# dp19Stats is not created anywhere in this script; print it only if it
# has been provided some other way
if (exists("dp19Stats")) {
  print(dp19Stats$finalTable)
}

